from typing import Dict, List, Tuple, Any
import copy
import sys
import json
import numpy as np
from pathlib import Path

from qa.table_bert.config import TableBertConfig, BERT_CONFIGS
from qa.table_bert.table_bert import TableBertModel
from qa.table_bert.vanilla_table_bert import VanillaTableBert



def get_table_bert_model(config: Dict, use_proxy=False, master=None):
    """Load a pretrained TaBERT model described by ``config`` and print its config.

    The model location is the first non-empty value among the config keys
    ``table_bert_model_or_config``, ``table_bert_config_file`` and
    ``table_bert_model``, checked in that order. If all three are empty or
    missing, the value of ``table_bert_model`` (possibly ``None``) is passed
    through unchanged. Keyword overrides found in
    ``config['table_bert_extra_config']`` are forwarded to
    ``TableBertModel.from_pretrained``. A missing or ``None`` entry is treated
    as "no overrides".

    Args:
        config: Dict-like configuration; only ``.get`` is used.
        use_proxy: Unused here; kept for interface compatibility.
        master: Unused here; kept for interface compatibility.

    Returns:
        Whatever ``TableBertModel.from_pretrained`` returns.

    Side effects:
        Prints ``'Table Bert Config'`` followed by ``vars(model.config)`` as
        indented JSON to stdout. Example output recorded for a
        ``bert-base-uncased`` checkpoint::

            {
              "base_model_name": "bert-base-uncased",
              "column_delimiter": "[SEP]",
              "context_first": true,
              "column_representation": "mean_pool_column_name",
              "max_cell_len": 5,
              "max_sequence_len": 512,
              "max_context_len": 256,
              "do_lower_case": true,
              "cell_input_template": ["column", "|", "type", "|", "value"],
              "masked_context_prob": 0.15,
              "masked_column_prob": 0.2,
              "max_predictions_per_seq": 100,
              "context_sample_strategy": "nearest",
              "table_mask_strategy": "column",
              "vocab_size": 30522,
              "hidden_size": 768,
              "num_hidden_layers": 12,
              "num_attention_heads": 12,
              "hidden_act": "gelu",
              "intermediate_size": 3072,
              "hidden_dropout_prob": 0.1,
              "attention_probs_dropout_prob": 0.1,
              "max_position_embeddings": 512,
              "type_vocab_size": 2,
              "initializer_range": 0.02
            }
    """
    # Fall back through the supported keys in priority order. Compare
    # explicitly instead of using set membership, so an unhashable config
    # value does not raise TypeError.
    model_name_or_path = None
    for key in ('table_bert_model_or_config',
                'table_bert_config_file',
                'table_bert_model'):
        model_name_or_path = config.get(key)
        if model_name_or_path is not None and model_name_or_path != '':
            break

    # `or {}` also covers an explicit `None` value, which `**` cannot unpack.
    table_bert_extra_config = config.get('table_bert_extra_config') or {}

    model = TableBertModel.from_pretrained(
        model_name_or_path,
        **table_bert_extra_config
    )

    print('Table Bert Config')
    print(json.dumps(vars(model.config), indent=2))

    return model


def get_table_bert_input_from_context(
    env_context: List[Dict],
    bert_model: 'TableBertModel',
    is_training: bool,
    **kwargs
) -> Tuple[List[Any], List[Any]]:
    """Split environment examples into parallel lists of questions and tables.

    Args:
        env_context: Sequence (or any iterable) of example dicts. Each must
            provide ``'question_tokens'`` and ``'table'`` keys.
        bert_model: Unused here; kept for interface compatibility.
        is_training: Unused here; kept for interface compatibility.
        **kwargs: Optional ``content_snapshot_strategy``. If truthy, it must be
            ``'sampled_rows'`` or ``'synthetic_row'``. It is only validated,
            not otherwise used.

    Returns:
        ``(contexts, tables)``, where ``contexts[i]`` is
        ``env_context[i]['question_tokens']`` and ``tables[i]`` is
        ``env_context[i]['table']``.

    Raises:
        ValueError: If ``content_snapshot_strategy`` is set to an unsupported
            value.
        KeyError: If an example lacks ``'question_tokens'`` or ``'table'``.
    """
    allowed_strategies = ('sampled_rows', 'synthetic_row')
    content_snapshot_strategy = kwargs.get('content_snapshot_strategy', None)
    # Use an explicit check, not `assert`: asserts are stripped under `python -O`,
    # which would silently accept invalid values.
    if content_snapshot_strategy and content_snapshot_strategy not in allowed_strategies:
        raise ValueError(
            f'Unsupported content_snapshot_strategy {content_snapshot_strategy!r}; '
            f'expected one of {allowed_strategies}'
        )

    # Single pass so one-shot iterators (e.g. generators) are handled correctly.
    contexts = []
    tables = []
    for example in env_context:
        contexts.append(example['question_tokens'])
        tables.append(example['table'])

    return contexts, tables
